{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cdd75b2b-e241-4b2a-a65c-5a089bcdaf07",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image size is (28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlEAAAJnCAYAAAC3XwZuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsdElEQVR4nO3dfXBV5Z3A8V8gECyQbChFYMWkFaq7Hda4ULvrS9EKaHXGMlpro1KsQF/G7bSwK1ssAqIuyk61L07HmaUSVtY6tqW2dYpLbY0IFaxVqFanRiyIMLoqITegRNzc/aND1pQQkoeT3CR8PjN3pjn33Oc+ucmtX55zcm5RPp/PBwAAndKv0BMAAOiNRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFNAjLV68OIqKiuKcc84pyPMXFRVFUVFR1NbWFuT5gZ5PRAFHdDBoioqKCj2VPuHga9mR27nnnlvo6QKHUVzoCQAca44//vh27z9w4EDs3r07IiI++tGPdseUgAQiCqCbvfrqq+3e/81vfjP+5V/+JSIiZs6c2R1TAhI4nAfQw3z/+9+PiIizzjorTj755ALPBjgcEQV0mYaGhrjvvvviyiuvjPHjx8ewYcNi0KBBUVFREVdccUVs3Lixw2Pdf//9MWnSpBg2bFgMHjw4JkyYEHfeeWf87//+7xHncMstt8THPvaxKC8vj5KSkhgzZkxUV1d36vm7y29+85t4/vnnIyJi1qxZBZ4N0B4RBXSZO+64I6qrq+Pee++NZ599Ng4cOBARES+//HL84Ac/iDPOOCO+853vHHGcf/3Xf43LL788HnvssYiI2L9/fzz11FPxla98JS666KJoampq83GbNm2Kk08+ORYsWBBPPPFENDY2RklJSbzyyitx3333xRlnnBFLly7t9Pf13hPtt23b1unHt+fgKlRpaWlcdtllmY4NZEtEAV1m5MiRMWfOnNi4cWPU19dHY2NjvP322/HSSy/FV7/61YiImDt3bjz99NOHHWPz5s2xbNmy+Kd/+qd47bXXYvfu3VFfXx833XRTFBUVxX//93/H/PnzD3nctm3b4oILLojXXnstPv3pT8fvfve72L9/f+RyuXjttdfihhtuiP79+8f1118fDzzwQFe9BJ2yd+/euP/++yMi4oorroj3ve99BZ4R0K48wBEsWrQoHxH5rP8v49prr81HRH7mzJntPuf06dPbfPyCBQvyEZEvLi7O79y5s9V9n/70p9t9bD6fz99+++35iMifeuqph9x38LkfeeSRduf2pz/9qd3vsTP+4z/+o2XcJ598MrNxga5hJQoomIsuuigiItavX9/ufgsXLmxz+3XXXRfHHXdcvPvuu/HjH/+4Zfvu3btj9erVERHx9a9//bDjfu5zn4uIiC1btsRrr73W4XkvXrw48vl85PP5qKys7PDjjmT58uUREXHqqafGhAkTMhsX6BoucQB0qZdeeim+973vxSOPPBJbt26NxsbGaG5ubrXPK6+8ctjHjxkzJsaOHdvmfaWlpTFhwoRYv359PPnkky3bH3/88Zbn+MQnPtGheW7fvv2I12/qSn/4wx9i06ZNEeGEcugtRBTQZX7yk59EdXV1qxO/S0tLY9CgQVFUVBTvvPNO1NfXx759+w47xl//9V+3+xwH7/+f//mflm27du1q+d8dXWF66623OrRfVzm4CjVo0KC48sorCzoXoGMczgO6xJtvvhlXX311NDU1xSc+8Ymora2Nt956KxoaGuK1116LV199NX74wx8ecZyUj5o5eNmD4447ruWw25FuhfqMvoiId955J1atWhUREZdeemmUl5cXbC5Ax1mJArrEL37xi8jlclFeXh4///nP2/xLsyNduTui/UN9ERE7d+6MiIgRI0a0bBs5cmRERLz99tvx4osvHvZwYE/x05/+NN54442IcCgPehMrUUCX2LFjR0REnHzy
yYf9U/2HH364Q+Ns3bq1zfsaGxvjd7/7XURETJw4sWX7GWec0bKCdd9993Vq3oVw8FDe2LFjY9KkSQWeDdBRIgroEmVlZRER8cILL8T+/fsPuX/z5s1x7733dmism266qc3t3/zmN+Ptt9+O4uLiuOSSS1q2jxgxIj71qU9FRMS///u/xwsvvNDu+Ac/7LcQXn755ZaYvOaaa5IOXwKFIaKATnnjjTfave3ZsyciIqZOnRr9+vWL3bt3x5VXXtly2O2dd96J+++/P6ZOnRpDhw494vOVlZXFypUr46tf/WrLIa/Gxsb4t3/7t5a4uvbaaw85Af2b3/xmvP/9749cLhdnnXVW3H333dHQ0NDq+1i9enVccsklUV1d3anXIMsrlt99993R3NwcxcXFcfXVVx/VWED3ElFAp3zgAx9o93bwBO1x48bFddddFxERq1evjhNOOCH+6q/+KoYMGRKXX355DBkypEMf+VJVVRXz5s2L73znO3H88cfH+9///igvL49vfOMb0dzcHJMnT45bb731kMd96EMfil/+8pdRWVkZr7/+esycOTPKy8tj2LBhMXTo0PjABz4Ql156afzkJz855JIL3aW5uTlqamoiIuLCCy+MUaNGFWQeQBoRBXSZW2+9Nf7zP/8zTj/99DjuuOPiwIEDMXbs2Lj++uvj6aefjtGjR3donNtuuy3uu+++OPPMM6O5uTkGDhwYVVVV8e1vfzseeuihGDRoUJuPO+200+K5556LO++8MyZPnhzDhw9vuU7VuHHj4oorroj77ruv5cKc3e3hhx+O7du3R4QTyqE3Ksrn8/lCTwIAoLexEgUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJOhVH0Dc3Nwcu3btiqFDh/poBAAgc/l8PhobG2P06NHRr1/7a029KqJ27doVY8aMKfQ0AIA+bseOHXHCCSe0u0+vOpzXkc/ZAgA4Wh1pjl4VUQ7hAQDdoSPN0W0R9dvf/jYuvPDCKC8vj8GDB8fpp58e9957b3c9PQBAprrlnKja2to4//zzY+DAgfHZz342ysrKYvXq1XHllVfGtm3b4vrrr++OaQAAZKbLP4D43XffjVNOOSVeeeWVePzxx+O0006LiIjGxsb4x3/8x/jjH/8Yzz33XIwbN+6IY+VyuSgrK+vK6QIARENDQ5SWlra7T5cfzvv1r38dW7dujSuuuKIloCL+fMLWDTfcEO+++26sWLGiq6cBAJCpLo+o2traiIiYOnXqIfcd3Pboo4929TQAADLV5RFVV1cXEdHm4bry8vIYPnx4yz4AAL1Fl59Y3tDQEBFx2HOZSktL45VXXmnzvqampmhqamr5OpfLZT9BAIAEPfo6UUuXLo2ysrKWm6uVAwA9RZdH1MEVqIMrUn+pvb+4mz9/fjQ0NLTcduzY0WXzBADojC6PqIPnQrV13lN9fX288cYbh728QUlJSZSWlra6AQD0BF0eUZMmTYqIiLVr1x5y38FtB/cBAOgtuuVimyeffHLs3LkzNm7cGFVVVRHR+mKbf/jDH+LDH/7wEcdysU0AoDt05GKbXf7XecXFxbF8+fI4//zz4+yzz47q6uooLS2N1atXx5/+9Ke4+eabOxRQAAA9SZevRB30xBNPxKJFi+Lxxx+Pd955Jz7ykY/E1772tbjyyis7PIaVKACgO3RkJarbIioLIgoA6A494rPzAAD6IhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBE
AQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkKBbIqqysjKKioravH3pS1/qjikAAGSquLueqKysLL72ta8dsn3ixIndNQUAgMwU5fP5fFc/SWVlZUREbNu27ajGyeVyUVZWdvQTAgBoR0NDQ5SWlra7j3OiAAASdNvhvKampli5cmXs3LkzysvL44wzzohTTz21u54eACBT3RZRr776alx99dWttl1wwQVxzz33xPDhw9t8TFNTUzQ1NbV8ncvlunKKAAAd1i2H86655pqora2N119/PXK5XGzcuDE++clPxkMPPRQXX3xxHO60rKVLl0ZZWVnLbcyYMd0xXQCAI+qWE8vb0tzcHJMmTYr169fHgw8+GBdddNEh+7S1EiWkAICu1qNPLO/Xr198/vOfj4iIDRs2tLlPSUlJlJaWtroBAPQEBf3rvIPnQr311luFnAYAQKcVNKI2bdoUEf9/HSkAgN6iyyPqueeeiz179hyyff369XH77bdHSUlJXHLJJV09DQCATHX5JQ7uv//+WLZsWZx33nlRWVkZJSUl8eyzz8batWujX79+cdddd8WJJ57Y1dMAAMhUl0fUueeeG88//3w89dRT8eijj8b+/fvj+OOPj8svvzzmzJkTp59+eldPAQAgcwW7xEEKn50HAHSHHn2JAwCA3kxEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQILiQk+AY9enP/3pTMebPXt2ZmPt2rUrs7H279+f2Vj/9V//ldlYERGvvvpqZmO9+OKLmY0F0BtYiQIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEhTl8/l8oSfRUblcLsrKygo9DTLy0ksvZTpeZWVlpuMdCxobGzMb6w9/+ENmY8HhvPLKK5mNtWzZsszGevLJJzMbi56hoaEhSktL293HShQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkKC40BPg2DV79uxMx/u7v/u7zMZ6/vnnMxvrb/7mbzIb6+///u8zGysi4pxzzslsrH/4h3/IbKwdO3ZkNtaYMWMyG6sne/fddzMb6/XXX89srIiIUaNGZTpeVl5++eXMxnryySczG4vew0oUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJCguNAT4Nj1q1/9qkePl5WHHnqo0FM4rPLy8szGqqqqymys3/3ud5mN9dGPfjSzsXqy/fv3ZzbWCy+8kNlYERHPP/98ZmMNGzYss7G2bt2a2Vgcm6xEAQAkEFEAAAlEFABAgk5H1KpVq+KLX/xiTJw4MUpKSqKoqChqamoOu38ul4u5c+dGRUVFlJSUREVFRcydOzdyudzRzBsAoKA6fWL5ggULYvv27TF8+PAYNWpUbN++/bD77tu3LyZNmhSbN2+OKVOmRHV1dWzZsiXuuOOOeOSRR2L9+vUxePDgo/oGAAAKodMrUcuXL49t27bF66+/Hl/60pfa3XfZsmWxefPmmDdvXqxduzZuvfXWWLNm
TSxcuDA2b94cy5YtS544AEAhdTqiJk+eHBUVFUfcL5/Px/Lly2PIkCGxcOHCVvfNnz8/ysvL4/vf/37k8/nOTgEAoOC67MTyurq62LVrV5x55pmHHLIbNGhQfPzjH4+dO3fGiy++2FVTAADoMl0aURER48aNa/P+g9sP7teWpqamyOVyrW4AAD1Bl0VUQ0NDRESUlZW1eX9paWmr/dqydOnSKCsra7mNGTMm+4kCACTo0deJmj9/fjQ0NLTcduzYUegpAQBERBd+dt7BFajDrTQdPDR3uJWqiIiSkpIoKSnJfnIAAEepy1aijnTO05HOmQIA6Mm6NKJGjx4dGzZsiH379rW6b//+/bFu3boYPXp0jB07tqumAADQZbosooqKimLWrFmxd+/eWLJkSav7li5dGvX19TFr1qwoKirqqikAAHSZTp8TtXz58li/fn1ERDzzzDMt22prayMiYtq0aTFt2rSIiJg3b1787Gc/i2XLlsXTTz8dEyZMiC1btsSaNWuiqqoq5s2bl813AQDQzTodUevXr4+VK1e22rZhw4bYsGFDRERUVla2RNTgwYOjtrY2brzxxvjRj34UtbW1MXLkyJgzZ04sWrTI5+YBAL1WpyOqpqYmampqOrx/WVlZ3H777XH77bd39qkAAHqsHn2dKACAnqoo34s+ATiXy7V7XSkAjt6ll16a6Xj3339/ZmM9++yzmY117rnnZjbW7t27MxuLnqGhoaHl01UOx0oUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJCguNATAODojRgxIrOxvve972U2VkREv37Z/Xt9yZIlmY21e/fuzMbi2GQlCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIUFzoCQBw9K699trMxvrABz6Q2VgREfX19ZmN9cc//jGzseBoWYkCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABIUF3oCAMeqM888M7Oxvv71r2c2VtamTZuW2VjPPvtsZmPB0bISBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAmKCz0BgGPVhRdemNlYAwYMyGysX/3qV5mNFRHx+OOPZzoe9BRWogAAEogoAIAEnY6oVatWxRe/+MWYOHFilJSURFFRUdTU1LS57+LFi6OoqKjN26BBg4527gAABdPpc6IWLFgQ27dvj+HDh8eoUaNi+/btR3zMjBkzorKysvUTFzsdCwDovTpdMsuXL49x48ZFRUVF3HrrrTF//vwjPubqq6+Oc845J2V+AAA9UqcjavLkyV0xDwCAXqVbjqk99thj8cQTT0T//v3jlFNOicmTJ0dJSUl3PDUAQJfolohauHBhq69HjRoVK1eujClTpnTH0wMAZK5LL3FQVVUVK1eujG3btsXbb78ddXV1cdNNN8WePXvi4osvji1btrT7+Kampsjlcq1uAAA9QZdG1LRp0+Jzn/tcVFRUxKBBg2Ls2LGxYMGC+Pa3vx379++Pm2++ud3HL126NMrKylpuY8aM6crpAgB0WEEutjljxowoLi6ODRs2tLvf/Pnzo6GhoeW2Y8eObpohAED7CnKxpoEDB8bQoUPjrbfeane/kpISJ6ADAD1SQVai6urqor6+/pALcAIA9BZdFlGNjY3x+9///pDt9fX1MXPmzIiIqK6u7qqnBwDoUklXLF+/fn1ERDzz
zDMt22prayPizyeTT5s2Ld5888049dRTY+LEiTF+/PgYMWJE7Ny5M9asWRNvvvlmTJkyJebMmZPddwIA0I06HVHr16+PlStXttq2YcOGlpPEKysrY9q0aTFs2LC49tprY+PGjfHzn/889uzZE4MHD47x48fHVVddFbNmzYr+/ftn810AAHSzTkdUTU1N1NTUHHG/0tLSuPPOO1PmBADQ4xXkxHIAgN6uIJc4AOitjjvuuMzGuuCCCzIb65133slsrEWLFmU2VkTEgQMHMh0PegorUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAguJCTwCgN7nuuusyG+u0007LbKyHHnoos7F+85vfZDYW9GVWogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABMWFngBAV7rooosyHe+GG27IbKxcLpfZWEuWLMlsLKBjrEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAguJCTwDgL73//e/PbKzvfOc7mY0VEdG/f//MxvrFL36R2VgbN27MbCygY6xEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJigs9AaBv6N+/f2ZjPfTQQ5mN9cEPfjCzsSIitm7dmtlYN9xwQ2ZjAd3PShQAQAIRBQCQoFMRtXPnzvjWt74VU6dOjRNPPDEGDhwYI0eOjEsvvTQ2bdrU5mNyuVzMnTs3KioqoqSkJCoqKmLu3LmRy+Uy+QYAAAqhUxH13e9+N+bMmRMvvfRSTJkyJf75n/85zjrrrPjpT38aZ5xxRtx///2t9t+3b19MmjQp7rjjjjj55JNjzpw58bd/+7dxxx13xKRJk2Lfvn2ZfjMAAN2lUyeWn3766bFu3bo4++yzW21/7LHH4rzzzosvf/nL8alPfSpKSkoiImLZsmWxefPmmDdvXtx2220t+y9atCiWLFkSy5YtixtvvDGDbwMAoHsV5fP5fBYDnX/++bF27dr47W9/GxMnTox8Ph8nnHBC5HK5ePXVV2Pw4MEt++7fvz9Gjx4d73vf+2LHjh1RVFTUoefI5XJRVlaWxXSBjGX513kbN27MbKwJEyZkNlZEtn+dd8EFF2Q2VpbzAiIaGhqitLS03X0yO7F8wIABERFRXPznxa26urrYtWtXnHnmma0CKiJi0KBB8fGPfzx27twZL774YlZTAADoNplE1MsvvxwPP/xwjBw5MsaPHx8Rf46oiIhx48a1+ZiD2w/uBwDQmxz1xTYPHDgQ06dPj6ampli2bFnLkn5DQ0NExGEPvx1cIju4X1uampqiqamp5Wt/0QcA9BRHtRLV3Nwc11xzTaxbty5mz54d06dPz2peERGxdOnSKCsra7mNGTMm0/EBAFIlR1Q+n4/Zs2fHqlWr4qqrroq77rqr1f0HV6AOt9J0cFWpvRPF58+fHw0NDS23HTt2pE4XACBTSYfzmpubY9asWbFixYqorq6Ompqa6NevdY8d6ZynI50zFRFRUlLScrkEAICepNMrUe8NqMsvvzzuueeeNv+0edy4cTF69OjYsGHDIRfV3L9/f6xbty5Gjx4dY8eOTZ89AECBdCqimpubY+bMmbFixYq47LLLYtWqVYe9NkxRUVHMmjUr9u7dG0uWLGl139KlS6O+vj5mzZrV4WtEAQD0JJ06nLdkyZKoqamJIUOGxIc//OG4+eabD9ln
2rRpUVVVFRER8+bNi5/97GexbNmyePrpp2PChAmxZcuWWLNmTVRVVcW8efMy+SYAALpbpyJq27ZtERGxd+/euOWWW9rcp7KysiWiBg8eHLW1tXHjjTfGj370o6itrY2RI0fGnDlzYtGiRYdchBMAoLfI7GNfuoOPfYGey8e+dJ6PfYGeq1s/9gUA4Fhy1FcsB4iIOOmkkzIbK+vVoyzNnTs3s7GsHkHvZiUKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEhQXOgJAIVTUVGR2Vhr167NbKwsXXfddZmO9+CDD2Y6HtB7WYkCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABMWFngBQOF/4whcyG+vEE0/MbKwsPfroo5mOl8/nMx0P6L2sRAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACYoLPQGg484666xMx/vKV76S6XgAxxIrUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAguJCTwDouLPPPjvT8YYMGZLpeFnZunVrZmPt3bs3s7EA3stKFABAAhEFAJBARAEAJOhURO3cuTO+9a1vxdSpU+PEE0+MgQMHxsiRI+PSSy+NTZs2HbL/4sWLo6ioqM3boEGDMvsmAAC6W6dOLP/ud78bt912W5x00kkxZcqUGDFiRNTV1cUDDzwQDzzwQPzgBz+Iz3zmM4c8bsaMGVFZWdn6iYud0w4A9F6dKpnTTz891q1bd8hfCD322GNx3nnnxZe//OX41Kc+FSUlJa3uv/rqq+Occ8456skCAPQUnTqcd8kll7T5J9Znn312nHvuubF79+545plnMpscAEBPldkxtQEDBvx5wDYO0z322GPxxBNPRP/+/eOUU06JyZMnH7JaBQDQm2QSUS+//HI8/PDDMXLkyBg/fvwh9y9cuLDV16NGjYqVK1fGlClT2h23qakpmpqaWr7O5XJZTBcA4Kgd9SUODhw4ENOnT4+mpqZYtmxZ9O/fv+W+qqqqWLlyZWzbti3efvvtqKuri5tuuin27NkTF198cWzZsqXdsZcuXRplZWUttzFjxhztdAEAMnFUEdXc3BzXXHNNrFu3LmbPnh3Tp09vdf+0adPic5/7XFRUVMSgQYNi7NixsWDBgvj2t78d+/fvj5tvvrnd8efPnx8NDQ0ttx07dhzNdAEAMpMcUfl8PmbPnh2rVq2Kq666Ku66664OP3bGjBlRXFwcGzZsaHe/kpKSKC0tbXUDAOgJkiKqubk5Zs6cGXfffXdUV1dHTU1N9OvX8aEGDhwYQ4cOjbfeeivl6QEACq7TEdXc3ByzZs2KFStWxOWXXx733HNPq/OgOqKuri7q6+sPuQAnAEBv0amIOrgCtWLFirjsssti1apVhw2oxsbG+P3vf3/I9vr6+pg5c2ZERFRXVydMGQCg8Dp1iYMlS5ZETU1NDBkyJD784Q+3eWL4tGnToqqqKt5888049dRTY+LEiTF+/PgYMWJE7Ny5M9asWRNvvvlmTJkyJebMmZPZNwIA0J06FVHbtm2LiIi9e/fGLbfc0uY+lZWVUVVVFcOGDYtrr702Nm7cGD//+c9jz549MXjw4Bg/fnxcddVVMWvWrE4fBgQA6Ck6FVE1NTVRU1PToX1LS0vjzjvvTJkTAECPl9nHvgDHtiNdPLczzjvvvMzG2r17d2ZjAbzX
UV+xHADgWCSiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACCBiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAERfl8Pl/oSXRULpeLsrKyQk8DAOjjGhoaorS0tN19rEQBACQQUQAACUQUAEACEQUAkEBEAQAkEFEAAAlEFABAAhEFAJBARAEAJBBRAAAJRBQAQAIRBQCQQEQBACQQUQAACUQUAEACEQUAkEBEAQAk6FURlc/nCz0FAOAY0JHm6FUR1djYWOgpAADHgI40R1G+Fy3vNDc3x65du2Lo0KFRVFR02P1yuVyMGTMmduzYEaWlpd04QyK8/j2Bn0Fhef0Ly+tfWL399c/n89HY2BijR4+Ofv3aX2sq7qY5ZaJfv35xwgkndHj/0tLSXvkD7Cu8/oXnZ1BYXv/C8voXVm9+/cvKyjq0X686nAcA0FOIKACABH0yokpKSmLRokVRUlJS6Kkck7z+hednUFhe/8Ly+hfWsfT696oTywEAeoo+uRIFANDVRBQAQAIRBQCQQEQBACToUxH129/+Ni688MIoLy+PwYMHx+mnnx733ntvoad1zKisrIyioqI2b1/60pcKPb0+YdWqVfHFL34xJk6cGCUlJVFUVBQ1NTWH3T+Xy8XcuXOjoqIiSkpKoqKiIubOnRu5XK77Jt3HdOZnsHjx4sO+JwYNGtS9E+8Ddu7cGd/61rdi6tSpceKJJ8bAgQNj5MiRcemll8amTZvafIz3QHY6+/ofC7//veqK5e2pra2N888/PwYOHBif/exno6ysLFavXh1XXnllbNu2La6//vpCT/GYUFZWFl/72tcO2T5x4sTun0wftGDBgti+fXsMHz48Ro0aFdu3bz/svvv27YtJkybF5s2bY8qUKVFdXR1btmyJO+64Ix555JFYv359DB48uBtn3zd05mdw0IwZM6KysrLVtuLiPvN/v93mu9/9btx2221x0kknxZQpU2LEiBFRV1cXDzzwQDzwwAPxgx/8ID7zmc+07O89kK3Ovv4H9enf/3wfcODAgfxJJ52ULykpyT/11FMt23O5XP4jH/lIvri4OP/CCy8UcIbHhoqKinxFRUWhp9Gn/fKXv8xv27Ytn8/n80uXLs1HRH7FihVt7rtw4cJ8ROTnzZvX5vaFCxd29XT7pM78DBYtWpSPiPwjjzzSfRPsw3784x/n161bd8j2devW5QcMGJAfNmxYfv/+/S3bvQey1dnX/1j4/e8Th/N+/etfx9atW+OKK66I0047rWX70KFD44Ybboh33303VqxYUcAZQjYmT54cFRUVR9wvn8/H8uXLY8iQIbFw4cJW982fPz/Ky8vj+9//fuRdJq7TOvozIHuXXHJJnH322YdsP/vss+Pcc8+N3bt3xzPPPBMR3gNdoTOv/7GiT6yn1dbWRkTE1KlTD7nv4LZHH320O6d0zGpqaoqVK1fGzp07o7y8PM4444w49dRTCz2tY05dXV3s2rUrzj///EMOVwwaNCg+/vGPx09/+tN48cUXY9y4cQWa5bHjscceiyeeeCL69+8fp5xySkyePPmYuJpzdxowYEBE/P9hIu+B7vWXr/979eXf/z4RUXV1dRERbb4RysvLY/jw4S370LVeffXVuPrqq1ttu+CCC+Kee+6J4cOHF2ZSx6D23hPv3V5XV+c/IN3gL1dCRo0aFStXrowpU6YUaEZ9y8svvxwPP/xwjBw5MsaPHx8R3gPdqa3X/7368u9/nzic19DQEBF/Pqm5LaWlpS370HWuueaaqK2tjddffz1yuVxs3LgxPvnJT8ZDDz0UF198sWXzbtSR98R796NrVFVVxcqVK2Pbtm3x9ttvR11dXdx0002xZ8+euPjii2PLli2FnmKvd+DAgZg+fXo0NTXFsmXLon///hHhPdBdDvf6Rxwbv/99YiWKnuEv/7XxsY99LB588MGYNGlSrF+/Pn7xi1/E
RRddVKDZQfebNm1aq6/Hjh0bCxYsiOOPPz6+8IUvxM033xw//OEPCzO5PqC5uTmuueaaWLduXcyePTumT59e6CkdU470+h8Lv/99YiXq4L80Dvcvilwud9h/jdC1+vXrF5///OcjImLDhg0Fns2xoyPviffuR/eaMWNGFBcXe08chXw+H7Nnz45Vq1bFVVddFXfddVer+70HutaRXv/29KXf/z4RUe89tv2X6uvr44033nDMu4AOngv11ltvFXgmx4723hPv3e59URgDBw6MoUOHek8kam5ujpkzZ8bdd98d1dXVUVNTE/36tf7PmfdA1+nI69+evvT73yciatKkSRERsXbt2kPuO7jt4D50v4NXsv3Li63RdcaNGxejR4+ODRs2xL59+1rdt3///li3bl2MHj06xo4dW6AZHtvq6uqivr7eeyJBc3NzzJo1K1asWBGXX3553HPPPa3OwznIe6BrdPT1b09f+v3vExF13nnnxYc+9KG49957Y/PmzS3bGxsb46abbori4uJD/mKMbD333HOxZ8+eQ7avX78+br/99igpKYlLLrmk+yd2jCoqKopZs2bF3r17Y8mSJa3uW7p0adTX18esWbOiqKioQDPs+xobG+P3v//9Idvr6+tj5syZERFRXV3d3dPq1Q6ugKxYsSIuu+yyWLVq1WH/A+49kL3OvP7Hyu9/Ub6P/MnUI488Eueff36UlJREdXV1lJaWxurVq+NPf/pT3HzzzfGNb3yj0FPs0xYvXhzLli2L8847LyorK6OkpCSeffbZWLt2bfTr1y/uuuuumDVrVqGn2estX7481q9fHxERzzzzTDz11FNx5plntvxretq0aS0nc+7bty/OOuuslo+8mDBhQmzZsiXWrFkTVVVVPvIiUUd/Btu2bYsPfvCDMXHixBg/fnyMGDEidu7cGWvWrIk333wzpkyZEg8++GAMHDiwkN9Or7J48eK48cYbY8iQIfHVr361zWsSTZs2LaqqqiLCeyBrnXn9j5nf/4JdK70LbNq0KX/BBRfky8rK8scdd1x+4sSJ+VWrVhV6WseE2tra/Gc+85n82LFj80OHDs0PGDAgf8IJJ+Q/+9nP5jdt2lTo6fUZM2bMyEfEYW+LFi1qtf+ePXvyc+bMyY8ZMyY/YMCA/JgxY/Jz5szJ79mzpzDfQB/Q0Z9BQ0ND/tprr81PmDAhP3z48HxxcXG+rKwsf9ZZZ+Xvuuuu/LvvvlvYb6QXOtJrH218BI/3QHY68/ofK7//fWYlCgCgO/WJc6IAALqbiAIASCCiAAASiCgAgAQiCgAggYgCAEggogAAEogoAIAEIgoAIIGIAgBIIKIAABKIKACABCIKACDB/wFQvRF3frc1eQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 700x700 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from sklearn import svm\n",
    "from sklearn import tree\n",
    "from sklearn.linear_model import LogisticRegression, LinearRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_absolute_error\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "import seaborn as sns\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.multiclass import OneVsOneClassifier\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "\n",
    "import time\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Shared palette for any multi-series plots later in the notebook\n",
    "color_list = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']\n",
    "\n",
    "# Use the first CUDA GPU when available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "label_size = 18 # Font size for titles/axis labels\n",
    "ticklabel_size = 14 # Font size for tick labels\n",
    "\n",
    "# Load the MNIST test split WITHOUT a transform, so each sample stays a\n",
    "# (PIL.Image, int) pair that can be displayed directly\n",
    "imgDisp = torchvision.datasets.MNIST(root='./data', train=False, download=True)\n",
    "img, label = imgDisp[0]  # first test sample\n",
    "\n",
    "# PIL's Image.size is a (width, height) tuple, hence \"(28, 28)\"\n",
    "print(f'Image size is {img.size}')\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7,7))\n",
    "ax.imshow(img, cmap='gray') # Display image\n",
    "ax.tick_params(axis='both', which='major', labelsize=ticklabel_size) # Set tick label size\n",
    "ax.set_title(f\"Label: {label}\", fontsize=label_size)\n",
    "# plt.savefig(f'exp_character{label}.png', dpi=300) # Make figure clearer\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4429d946-9f01-415c-99db-8fd32a0e0257",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Load the MNIST dataset with torch\n",
    "# Preprocessing pipeline: convert images to tensors and normalize them\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),  # PIL.Image -> PyTorch tensor; pixel values rescaled from [0, 255] to [0, 1]\n",
    "    transforms.Normalize((0.1307,), (0.3081,))  # standard MNIST mean/std normalization, helps training\n",
    "])\n",
    "\n",
    "# Load the MNIST training set (download=True fetches it if missing) with the preprocessing above\n",
    "trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "\n",
    "# Load the MNIST test set with the same preprocessing\n",
    "testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "75667a78-4ef0-4b63-89a7-612a5786f7e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert a (tensor, digit-label) dataset into numpy arrays with binary labels:\n",
    "# digits 0-4 become class 0 (\"small\"), digits 5-9 become class 1 (\"large\").\n",
    "def convert_to_numpy(dataset):\n",
    "    '''Flatten every image into a 1-D feature vector and binarize its label.'''\n",
    "    features, targets = [], []\n",
    "    for image, digit in dataset:\n",
    "        # Flatten the image tensor so it can serve as a feature vector\n",
    "        features.append(image.numpy().reshape(-1))\n",
    "        # digit >= 5 -> 1 (large), otherwise 0 (small)\n",
    "        targets.append(int(digit >= 5))\n",
    "    return np.array(features), np.array(targets)\n",
    "\n",
    "# Binarized training split\n",
    "X_train, y_train = convert_to_numpy(trainset)\n",
    "\n",
    "# Same conversion for the test split\n",
    "X_test, y_test = convert_to_numpy(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d929b091-5d7f-45da-a5df-35287890ddcd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "逻辑回归模型判断大小数准确率: 0.8707\n",
      "逻辑回归模型判断大小数精确率: 0.8676834295136027\n",
      "逻辑回归模型判断大小数召回率: 0.8660769389014606\n",
      "逻辑回归模型判断大小数F1分数: 0.8668794399258726\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    }
   ],
   "source": [
    "# Binary (small-vs-large digit) classifier: logistic regression on raw pixels.\n",
    "# max_iter is raised from the default 100 to 1000 because lbfgs does not\n",
    "# converge on the un-scaled 784-dimensional pixel features (this cell\n",
    "# previously emitted a ConvergenceWarning); matches the setting used for the\n",
    "# multi-class logistic regression later in the notebook.\n",
    "logreg = LogisticRegression(max_iter=1000)\n",
    "\n",
    "# Fit on the binarized training data\n",
    "logreg.fit(X_train, y_train)\n",
    "\n",
    "# Predict on the held-out test set\n",
    "y_pred = logreg.predict(X_test)\n",
    "\n",
    "# Accuracy\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(\"逻辑回归模型判断大小数准确率:\", accuracy)\n",
    "\n",
    "# Precision\n",
    "precision = precision_score(y_test, y_pred)\n",
    "print(\"逻辑回归模型判断大小数精确率:\", precision)\n",
    "\n",
    "# Recall\n",
    "recall = recall_score(y_test, y_pred)\n",
    "print(\"逻辑回归模型判断大小数召回率:\", recall)\n",
    "\n",
    "# F1 score\n",
    "f1 = f1_score(y_test, y_pred)\n",
    "print(\"逻辑回归模型判断大小数F1分数:\", f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8e5166f6-4060-4a77-9ebc-808b2af20e43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 loss: 0.4685239857479708\n",
      "Epoch 2 loss: 0.33298324512393235\n",
      "Epoch 3 loss: 0.3117782960234802\n",
      "Epoch 4 loss: 0.3006232999630574\n",
      "Epoch 5 loss: 0.2930099372861228\n",
      "Epoch 6 loss: 0.2882642666858905\n",
      "Epoch 7 loss: 0.2839406344698055\n",
      "Epoch 8 loss: 0.28079443517079483\n",
      "Epoch 9 loss: 0.2781492965752636\n",
      "Epoch 10 loss: 0.2757528300430856\n",
      "Softmax回归模型识别手写字母准确率: 92.24%\n"
     ]
    }
   ],
   "source": [
    "# 2. Softmax (multinomial logistic) regression on the full 10-class MNIST task.\n",
    "# All required libraries are already imported in the first cell, so the\n",
    "# duplicate import block that used to live here has been removed.\n",
    "\n",
    "# Preprocessing: tensor conversion + standard MNIST mean/std normalization\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),  # PIL.Image -> tensor, pixels rescaled from [0, 255] to [0, 1]\n",
    "    transforms.Normalize((0.1307,), (0.3081,))\n",
    "])\n",
    "\n",
    "# Training set + loader (batch size 64, reshuffled every epoch)\n",
    "trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)\n",
    "\n",
    "# Test set + loader (no shuffling needed for evaluation)\n",
    "testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)\n",
    "\n",
    "# Resolve the target device once instead of re-evaluating the same\n",
    "# 'cuda if available else cpu' expression at every use site\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "class SoftmaxRegression(nn.Module):\n",
    "    '''A single linear layer; CrossEntropyLoss applies log-softmax internally,\n",
    "    so the forward pass returns raw logits.'''\n",
    "    def __init__(self, input_size, num_classes):\n",
    "        super(SoftmaxRegression, self).__init__()\n",
    "        self.linear = nn.Linear(input_size, num_classes)  # maps features to class logits\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.linear(x)\n",
    "\n",
    "input_size = 784   # 28x28 images flattened\n",
    "num_classes = 10   # digits 0-9\n",
    "\n",
    "# Build the model on the selected device\n",
    "model = SoftmaxRegression(input_size, num_classes).to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()  # standard loss for multi-class classification\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
    "\n",
    "# Train for 10 epochs (the original comment claimed 100, but the loop runs 10)\n",
    "for epoch in range(10):\n",
    "    running_loss = 0.0\n",
    "    for inputs, labels in trainloader:\n",
    "        inputs = inputs.view(-1, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()  # clear accumulated gradients\n",
    "        outputs = model(inputs)  # forward pass -> logits\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()  # backprop\n",
    "        optimizer.step()  # parameter update\n",
    "\n",
    "        running_loss += loss.item()\n",
    "    print(f'Epoch {epoch + 1} loss: {running_loss / len(trainloader)}')\n",
    "\n",
    "# Evaluate accuracy on the test set (no gradients needed)\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for images, labels in testloader:\n",
    "        images = images.view(-1, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print(f'Softmax回归模型识别手写数字准确率: {100 * correct / total}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd1ded6-a816-4fb2-809d-4bbeb44b23b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. Binary (small-vs-large digit) classification with a support vector machine.\n",
    "# SVC and the metric functions are already imported in the first cell\n",
    "# (`from sklearn import svm` and `from sklearn.metrics import ...`), so this\n",
    "# cell reuses them instead of re-importing.\n",
    "\n",
    "# Linear kernel; NOTE: SVC training scales super-linearly with sample count,\n",
    "# so fitting all 60k MNIST images can take a long time.\n",
    "clf = svm.SVC(kernel='linear')\n",
    "\n",
    "# Train on the binarized training data\n",
    "clf.fit(X_train, y_train)\n",
    "\n",
    "# Predict on the test set\n",
    "y_pred = clf.predict(X_test)\n",
    "\n",
    "# Accuracy\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(\"支持向量机判断大小数准确率:\", accuracy)\n",
    "\n",
    "# Precision\n",
    "precision = precision_score(y_test, y_pred)\n",
    "print(\"支持向量机判断大小数精确率:\", precision)\n",
    "\n",
    "# Recall\n",
    "recall = recall_score(y_test, y_pred)\n",
    "print(\"支持向量机判断大小数召回率:\", recall)\n",
    "\n",
    "# F1 score\n",
    "f1 = f1_score(y_test, y_pred)\n",
    "print(\"支持向量机判断大小数F1分数:\", f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bb5cf66-3019-4dd2-86de-53a6c36961ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ftrExtract(object):\n",
    "    '''\n",
    "    Extracts a 112-dim feature vector from a 28x28 image tensor:\n",
    "    per-column mean, per-row mean, per-column std, per-row std.\n",
    "    '''\n",
    "    def __call__(self, tensor):\n",
    "        tensor = tensor.squeeze() # Compress redundant dimensions (1x28x28 -> 28x28)\n",
    "\n",
    "        mean_width = tensor.mean(axis=0)   # mean over rows -> one value per column\n",
    "        mean_height = tensor.mean(axis=1)  # mean over columns -> one value per row\n",
    "\n",
    "        std_width = tensor.std(axis=0)\n",
    "        std_height = tensor.std(axis=1)\n",
    "\n",
    "        return torch.cat([mean_width, mean_height, std_width, std_height])\n",
    "\n",
    "# Tensor conversion followed by summary-statistics feature extraction\n",
    "transform = transforms.Compose([transforms.ToTensor(), ftrExtract()])\n",
    "\n",
    "# Load MNIST; root matches the './data' used by earlier cells so the dataset\n",
    "# is not downloaded a second time on case-sensitive file systems\n",
    "trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "# Count class frequencies from the stored label tensor; iterating the dataset\n",
    "# itself would needlessly run the feature-extraction transform on every image\n",
    "from collections import Counter\n",
    "train_class_counts = Counter(trainset.targets.tolist())\n",
    "test_class_counts = Counter(testset.targets.tolist())\n",
    "\n",
    "# Per-class counts and ratios for both splits\n",
    "for i in range(10):\n",
    "    cls_counts_train = train_class_counts.get(i, 0)\n",
    "    cls_ratio_train = cls_counts_train / len(trainset)\n",
    "    cls_counts_test = test_class_counts.get(i, 0)\n",
    "    cls_ratio_test = cls_counts_test / len(testset)\n",
    "\n",
    "    print(f\"Class {i}: Trainset - {cls_counts_train} ({cls_ratio_train:.2%}), Testset - {cls_counts_test} ({cls_ratio_test:.2%})\")\n",
    "\n",
    "batch_size = 42\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "\n",
    "# Peek at one batch to determine the extracted feature dimensionality\n",
    "dataiter = iter(trainloader)\n",
    "data, labels = next(dataiter)\n",
    "\n",
    "input_size = data[0].numpy().shape[0]\n",
    "print(f'Input_size is {input_size}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bfe9da0-90f5-4b11-98de-21decc4b09c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loader_to_numpy(loader, n_features):\n",
    "    '''Drain a DataLoader into flat numpy feature/label arrays.'''\n",
    "    feature_batches, label_batches = [], []\n",
    "    for batch_image, batch_label in loader:\n",
    "        feature_batches.append(batch_image.view(-1, n_features).numpy())\n",
    "        label_batches.append(batch_label.numpy())\n",
    "    return np.vstack(feature_batches), np.concatenate(label_batches)\n",
    "\n",
    "# Materialize the extracted features as numpy arrays for scikit-learn\n",
    "X_train, y_train = loader_to_numpy(trainloader, input_size)\n",
    "print(f'Shapes of X_train and Y_train: {X_train.shape} and {y_train.shape}')\n",
    "\n",
    "X_test, y_test = loader_to_numpy(testloader, input_size)\n",
    "print(f'Shapes of X_test and y_test: {X_test.shape} and {y_test.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "340b98fe-cf31-4978-8139-839a6f9738a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit an ordinary least-squares model that regresses the class index directly\n",
    "lr_model = LinearRegression()\n",
    "lr_model.fit(X_train, y_train)\n",
    "\n",
    "# Continuous predictions on the test set\n",
    "y_pred = lr_model.predict(X_test)\n",
    "\n",
    "# Map the regression output back to a class label: clamp into the valid\n",
    "# label range, then round to the nearest integer\n",
    "y_pred_constrained = np.clip(y_pred, 0, np.max(y_test))\n",
    "y_pred_rounded = np.round(y_pred_constrained).astype(int)\n",
    "print(f\"Predicted classes: {np.unique(y_pred_rounded)}\")\n",
    "\n",
    "# Fraction of exactly-matching labels\n",
    "accuracy = (y_pred_rounded == y_test).mean()\n",
    "print(f\"Real classes: {np.unique(y_test)}\")\n",
    "print(f\"Accuracy of linear regression model: {accuracy:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49d5b871-f29d-4254-91ea-6606b19871ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def manual_confusion_matrix(y_true, y_pred, num_classes):\n",
    "    '''Build a num_classes x num_classes confusion matrix (rows = true labels,\n",
    "    columns = predicted labels).'''\n",
    "    cm = np.zeros((num_classes, num_classes), dtype=int)\n",
    "    for true_lbl, pred_lbl in zip(y_true, y_pred):\n",
    "        cm[true_lbl, pred_lbl] += 1\n",
    "    return cm\n",
    "\n",
    "# Number of distinct digit classes present in the test labels\n",
    "num_classes = len(np.unique(y_test))\n",
    "\n",
    "# Confusion matrix of the rounded linear-regression predictions\n",
    "cm = manual_confusion_matrix(y_test, y_pred_rounded, num_classes)\n",
    "\n",
    "# One line per true label, listing the count for every predicted label\n",
    "for i in range(num_classes):\n",
    "    row_text = f\"Real label {i}, \"\n",
    "    for j in range(num_classes):\n",
    "        row_text += f\"Predict label {j} ({cm[i, j]}), \"\n",
    "    print(row_text[:-2])  # strip the trailing comma and space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaff628d-4b72-40c4-91c9-ddc39793f058",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the logistic regression model\n",
    "lr_model = LogisticRegression(max_iter=1000)\n",
    "\n",
    "# Train the model\n",
    "lr_model.fit(X_train, y_train)\n",
    "\n",
    "# Make predictions on the test set.\n",
    "# Unlike the linear-regression cell above, predict() already returns\n",
    "# integer class labels, so no rounding/clipping is needed here.\n",
    "y_pred = lr_model.predict(X_test)\n",
    "y_pred_proba = lr_model.predict_proba(X_test)\n",
    "\n",
    "print(f\"Predicted classes: {np.unique(y_pred)}\")\n",
    "\n",
    "# Calculate accuracy\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Accuracy of logistic regression model: {accuracy:.4f}\")\n",
    "\n",
    "# Calculate and print classification report\n",
    "print(\"\\nClassification Report:\")\n",
    "print(classification_report(y_test, y_pred))\n",
    "\n",
    "cm = confusion_matrix(y_test, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b64dc56-3e3f-4fd9-9ac4-9c32fb8d1a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select an example of character 1 randomly\n",
    "char_1_indices = np.where(y_test == 1)[0]\n",
    "random_char_1_index = np.random.choice(char_1_indices)\n",
    "random_char_1 = X_test[random_char_1_index]\n",
    "\n",
    "# Select an example of character 2 randomly\n",
    "char_2_indices = np.where(y_test == 2)[0]\n",
    "random_char_2_index = np.random.choice(char_2_indices)\n",
    "random_char_2 = X_test[random_char_2_index]\n",
    "\n",
    "# Select an example of character 6 randomly\n",
    "char_6_indices = np.where(y_test == 6)[0]\n",
    "random_char_6_index = np.random.choice(char_6_indices)\n",
    "random_char_6 = X_test[random_char_6_index]\n",
    "\n",
    "# Get predictions and probabilities for these examples\n",
    "examples = [random_char_1, random_char_2, random_char_6]\n",
    "example_predictions = lr_model.predict(examples)\n",
    "example_probabilities = lr_model.predict_proba(examples)\n",
    "\n",
    "print(\"Predicted classes: \", example_predictions)\n",
    "\n",
    "# Display probabilities in percentage style, one line per example\n",
    "print(\"Probabilities:\")\n",
    "for i, probs in enumerate(example_probabilities):\n",
    "    prob_strings = [f\"{prob:.2%}\" for prob in probs]\n",
    "    print(f\"Example {i+1}: {', '.join(prob_strings)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32cb04cf-bd3e-4c02-9300-ea505967a772",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "# Generate x values\n",
    "x = np.linspace(-10, 10, 100)\n",
    "\n",
    "# Calculate sigmoid values\n",
    "y = sigmoid(x)\n",
    "\n",
    "# Plot the sigmoid curve\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(x, y, 'b-', linewidth=2)\n",
    "ax.set_xlabel('z', fontsize=label_size)\n",
    "ax.set_ylabel('y', fontsize=label_size)\n",
    "\n",
    "# Set x-ticks\n",
    "xticks = np.arange(-10.0, 10.1, 2.5)\n",
    "ax.set_xticks(xticks)\n",
    "\n",
    "# Modify tick labels\n",
    "xticklabels = ['-∞' if x == -10 else ('+∞' if x == 10 else str(x)) for x in xticks]\n",
    "ax.set_xticklabels(xticklabels)\n",
    "ax.tick_params(axis='both', which='major', labelsize=ticklabel_size)\n",
    "\n",
    "ax.set_xlim(-10, 10)\n",
    "ax.set_ylim(-0.05, 1.05)\n",
    "\n",
    "# Add vertical line at x=0\n",
    "ax.axvline(x=0, color='r', linestyle='--')\n",
    "\n",
    "# Add horizontal lines at y=0.5 and y=1\n",
    "ax.axhline(y=0.5, color='g', linestyle='--')\n",
    "ax.axhline(y=1, color='g', linestyle='--')\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('sigmoid_function.png', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4481374b-229f-404f-9b1d-5ed73b9e5e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract features and labels from trainset\n",
    "x_train = []\n",
    "y_train = []\n",
    "for image, label in trainset:\n",
    "    x_train.append(image.numpy())\n",
    "    y_train.append(1 if label == 1 else 0)  # Set label to 1 for character 1, 0 otherwise\n",
    "\n",
    "x_train = np.array(x_train)\n",
    "y_train = np.array(y_train)\n",
    "\n",
    "# Extract features and labels from testset\n",
    "x_test = []\n",
    "y_test = []\n",
    "for image, label in testset:\n",
    "    x_test.append(image.numpy())\n",
    "    y_test.append(1 if label == 1 else 0)  # Set label to 1 for character 1, 0 otherwise\n",
    "\n",
    "x_test = np.array(x_test)\n",
    "y_test = np.array(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dad7c159-6799-471c-9e60-6f9d0ee32213",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Generate prediction values (strictly inside (0, 1) so log() is defined)\n",
    "y_pred = np.linspace(0.001, 0.999, 1000)\n",
    "\n",
    "# Compute loss for y_true = 1 and y_true = 0\n",
    "loss_y1 = binary_cross_entropy(1, y_pred)\n",
    "loss_y0 = binary_cross_entropy(0, y_pred)\n",
    "\n",
    "# Plot the loss curve for y_true = 1\n",
    "fig, ax_y1 = plt.subplots(figsize=(10, 6))\n",
    "ax_y1.plot(y_pred, loss_y1, label='y_true = 1', color='blue')\n",
    "ax_y1.set_xlabel('Predicted y', fontsize=label_size)\n",
    "ax_y1.set_ylabel('Loss', fontsize=label_size)\n",
    "ax_y1.tick_params(axis='both', which='major', labelsize=ticklabel_size)\n",
    "plt.tight_layout()\n",
    "plt.savefig('binary_cross_entropy_loss1.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Plot the loss curve for y_true = 0 (use the Axes API consistently)\n",
    "fig, ax_y0 = plt.subplots(figsize=(10, 6))\n",
    "ax_y0.plot(y_pred, loss_y0, label='y_true = 0', color='red')\n",
    "ax_y0.set_xlabel('Predicted y', fontsize=label_size)\n",
    "ax_y0.set_ylabel('Loss', fontsize=label_size)\n",
    "ax_y0.tick_params(axis='both', which='major', labelsize=ticklabel_size)\n",
    "plt.tight_layout()\n",
    "plt.savefig('binary_cross_entropy_loss0.png', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d136811-b244-4f2f-a852-cd7c05a24a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate x values (start above 0 so log(x) is defined)\n",
    "x = np.linspace(0.01, 5, 1000)\n",
    "\n",
    "# Compute the NEGATED natural logarithm (the BCE loss shape)\n",
    "y = -np.log(x)\n",
    "\n",
    "# Create the plot\n",
    "fig, ax_log = plt.subplots(figsize=(6, 6))\n",
    "\n",
    "# Label must match the plotted curve: y = -ln(x), not ln(x)\n",
    "ax_log.plot(x, y, label='-ln(x)', color='k', linewidth=2)\n",
    "\n",
    "ax_log.set_xlabel('x', fontsize=label_size)\n",
    "ax_log.set_ylabel('-log$_e$(x)', fontsize=label_size)\n",
    "ax_log.tick_params(axis='both', which='major', labelsize=ticklabel_size)\n",
    "\n",
    "ax_log.set_xlim(-0.05, 5)\n",
    "ax_log.set_ylim(-2, 4)\n",
    "# Add vertical line at x=1, where -ln(x) crosses zero\n",
    "ax_log.axvline(x=1, color='gray', linestyle='--')\n",
    "\n",
    "# Add horizontal line at y=0\n",
    "ax_log.axhline(y=0, color='gray', linestyle='--')\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('natural_logarithm_function.png', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29357e55-27c8-4d4e-81a1-9e1232d4b82d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model\n",
    "w, b = train_logistic_regression(x_train, y_train)\n",
    "\n",
    "y_pred, y_proba = predict(x_test, w, b)\n",
    "\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "precision = precision_score(y_test, y_pred)\n",
    "recall = recall_score(y_test, y_pred)\n",
    "f1 = f1_score(y_test, y_pred)\n",
    "\n",
    "print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, Accuracy: {accuracy:.4f}, F1-Score: {f1:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0915ede-2de0-46fd-8e70-1669d3d0a1e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random select 3 examples from imgDisp and testset\n",
    "np.random.seed(42)\n",
    "idx = np.random.choice(len(imgDisp), 3)\n",
    "\n",
    "# Select instances\n",
    "imgDisp_select = [imgDisp[i] for i in idx]\n",
    "x_select = x_test[idx]\n",
    "y_select = y_test[idx]\n",
    "\n",
    "y_select_pred, y_select_proba = predict(x_select, w, b)\n",
    "\n",
    "# Check the selected instances' labels are the same\n",
    "for i in range(len(idx)):\n",
    "    print(f'Sample {i+1}: imgDisp label is {imgDisp_select[i][1]}, x label is {y_select[i]}')\n",
    "\n",
    "    # Display image from imgDisp\n",
    "    fig, ax = plt.subplots(figsize=(7,7))\n",
    "    ax.imshow(imgDisp_select[i][0], cmap='gray')\n",
    "    ax.tick_params(axis='both', which='major', labelsize=ticklabel_size) # Set tick label size\n",
    "    ax.set_title(f\"Label: {imgDisp_select[i][1]}, Prediction: {y_select_proba[i]:.4f}\", fontsize=label_size)\n",
    "\n",
    "    # plt.savefig(f'binary_prediction_{i+1}.png', dpi=300) # Make figure clearer\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60123e84-f506-4854-acce-32ae46730221",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define SVM classifier\n",
    "mdl_svm = svm.SVC(kernel='linear', probability=True)\n",
    "\n",
    "# Train model\n",
    "start_time = time.time()\n",
    "mdl_svm.fit(x_train, y_train)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f'Training time: {end_time - start_time:.2f} seconds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38184ecb-ffa4-42ab-9c61-51bafbe8dd94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make predictions and evaluate the model\n",
    "y_pred_svm = mdl_svm.predict(x_test)\n",
    "y_proba_svm = mdl_svm.predict_proba(x_test) # Output ratio\n",
    "\n",
    "accuracy = accuracy_score(y_test, y_pred_svm)\n",
    "precision = precision_score(y_test, y_pred_svm)\n",
    "recall = recall_score(y_test, y_pred_svm)\n",
    "f1 = f1_score(y_test, y_pred_svm)\n",
    "\n",
    "print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, Accuracy: {accuracy:.4f}, F1-Score: {f1:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc2eabea-c441-487a-9f0f-d903e7725da5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define DecisionTree classifier\n",
    "mdl_dt = tree.DecisionTreeClassifier()\n",
    "\n",
    "# Train model\n",
    "start_time = time.time()\n",
    "mdl_dt.fit(x_train, y_train)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f'Training time: {end_time - start_time:.2f} seconds')\n",
    "\n",
    "# Define Random Forest classifier\n",
    "mdl_rf = RandomForestClassifier(n_estimators=100)\n",
    "\n",
    "# Train model\n",
    "start_time = time.time()\n",
    "mdl_rf.fit(x_train, y_train)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f'Training time: {end_time - start_time:.2f} seconds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52e8ec4b-1ef7-477a-ad6e-821c65d55365",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred_dt = mdl_dt.predict(x_test)\n",
    "y_proba_dt = mdl_dt.predict_proba(x_test) # Output ratio\n",
    "\n",
    "accuracy = accuracy_score(y_test, y_pred_dt)\n",
    "precision = precision_score(y_test, y_pred_dt)\n",
    "recall = recall_score(y_test, y_pred_dt)\n",
    "f1 = f1_score(y_test, y_pred_dt)\n",
    "\n",
    "print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, Accuracy: {accuracy:.4f}, F1-Score: {f1:.4f}')\n",
    "\n",
    "y_pred_rf = mdl_rf.predict(x_test)\n",
    "y_proba_rf = mdl_rf.predict_proba(x_test) # Output ratio\n",
    "\n",
    "accuracy = accuracy_score(y_test, y_pred_rf)\n",
    "precision = precision_score(y_test, y_pred_rf)\n",
    "recall = recall_score(y_test, y_pred_rf)\n",
    "f1 = f1_score(y_test, y_pred_rf)\n",
    "\n",
    "print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, Accuracy: {accuracy:.4f}, F1-Score: {f1:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b672149-5a62-4111-b710-4870da40cfbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract features and labels from trainset\n",
    "x_train = []\n",
    "y_train = []\n",
    "for image, label in trainset:\n",
    "    x_train.append(image.numpy())\n",
    "    y_train.append(label)\n",
    "\n",
    "x_train = np.array(x_train)\n",
    "y_train = np.array(y_train)\n",
    "\n",
    "# Extract features and labels from testset\n",
    "x_test = []\n",
    "y_test = []\n",
    "for image, label in testset:\n",
    "    x_test.append(image.numpy())\n",
    "    y_test.append(label)\n",
    "\n",
    "x_test = np.array(x_test)\n",
    "y_test = np.array(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ac73ad-c862-42be-8cb0-05bc138e4454",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define logistic regression one-vs-rest multi-class classifier\n",
    "mdl_logic_ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000))\n",
    "\n",
    "# Train model\n",
    "start_time = time.time()\n",
    "mdl_logic_ovr.fit(x_train, y_train)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f'Training time: {end_time - start_time:.2f} seconds')\n",
    "\n",
    "# Make predictions and evaluate the model\n",
    "y_pred_logic_ovr = mdl_logic_ovr.predict(x_test)\n",
    "y_proba_logic_ovr = mdl_logic_ovr.predict_proba(x_test) # Output ratio\n",
    "\n",
    "accuracy = accuracy_score(y_test, y_pred_logic_ovr)\n",
    "print(f'Accuracy: {accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e8ef95e-ce3a-47dd-8c54-efcf9267401f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get class list: 0, 1, ..., 9\n",
    "class_list = np.sort(np.unique(y_train))\n",
    "\n",
    "# Create model list\n",
    "mdl_logic_list = []\n",
    "for c in class_list:\n",
    "    mdl_logic_list.append(LogisticRegression(max_iter=1000))\n",
    "\n",
    "# Train models separately\n",
    "for i in range(len(class_list)):\n",
    "    start_time = time.time()\n",
    "    mdl_logic_list[i].fit(x_train, (y_train == class_list[i]).astype(int))\n",
    "    end_time = time.time()\n",
    "    print(f'Training class {class_list[i]}, Training time: {end_time - start_time:.2f} seconds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0f83170-4378-4cd0-a9e7-f749e7e8ecb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_num = 10\n",
    "\n",
    "# Randomly select sample_num examples from imgDisp and testset\n",
    "np.random.seed(1)\n",
    "idx = np.random.choice(len(imgDisp), sample_num)\n",
    "\n",
    "# Select instances\n",
    "imgDisp_select = [imgDisp[i] for i in idx]\n",
    "testset_select = [testset[i] for i in idx]\n",
    "\n",
    "# Check the selected instances' labels are the same\n",
    "for i in range(sample_num):\n",
    "    x = testset_select[i][0].view(-1, input_size)\n",
    "\n",
    "    # Using model to predict character\n",
    "    y_pred_list = []\n",
    "    for j in range(len(mdl_logic_list)):\n",
    "        y_pred_list.append(mdl_logic_list[j].predict(x))\n",
    "\n",
    "    y_pred = np.argmax(np.array(y_pred_list), axis=0)[0]\n",
    "\n",
    "    # Display image from imgDisp\n",
    "    fig, ax = plt.subplots(figsize=(7,7))\n",
    "    ax.imshow(imgDisp_select[i][0], cmap='gray')\n",
    "    ax.tick_params(axis='both', which='major', labelsize=ticklabel_size) # Set tick label size\n",
    "    ax.set_title(f\"Label: {imgDisp_select[i][1]}, Prediction Label: {y_pred}\", fontsize=label_size)\n",
    "\n",
    "    print(f'Sample {i+1}: imgDisp label is {imgDisp_select[i][1]}, testset label is {testset_select[i][1]}, predict label is {y_pred}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fa643e6-5917-4f54-b979-903b3e704dee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction\n",
    "y_pred_list = []\n",
    "for i in range(len(mdl_logic_list)):\n",
    "    y_pred_list.append(mdl_logic_list[i].predict(x_test))\n",
    "\n",
    "y_pred = np.argmax(np.array(y_pred_list), axis=0)\n",
    "\n",
    "# Accuracy\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f'Accuracy: {accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6a4e59c-027b-46d9-b74c-e771b6354101",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define logistic regression one-vs-one classifier\n",
    "mdl_logic_ovo = OneVsOneClassifier(LogisticRegression(max_iter=1000))\n",
    "\n",
    "# Train model\n",
    "start_time = time.time()\n",
    "mdl_logic_ovo.fit(x_train, y_train)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f'Training time: {end_time - start_time:.2f} seconds')\n",
    "\n",
    "# Make predictions and evaluate the model\n",
    "y_pred = mdl_logic_ovo.predict(x_test)\n",
    "\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f'Accuracy: {accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60665c2-f3ed-4715-b70c-dc375c48cafd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get class list: 0, 1, ..., 9\n",
    "class_list = np.sort(np.unique(y_train))\n",
    "\n",
    "# Create model matrix to save models\n",
    "mdl_logic_matrix = {}\n",
    "for cls_p in class_list:\n",
    "    mdl_logic_matrix[cls_p] = {}\n",
    "    for cls_n in class_list:\n",
    "        if cls_p == cls_n:\n",
    "            continue\n",
    "        mdl_logic_matrix[cls_p][cls_n] = LogisticRegression(max_iter=1000)\n",
    "\n",
    "for cls_p in class_list:\n",
    "    # Training data of positive class\n",
    "    x_train_ovo_p = x_train[(y_train == cls_p), :]\n",
    "    y_train_ovo_p = np.ones(x_train_ovo_p.shape[0])\n",
    "\n",
    "    # Testing data of positive class\n",
    "    x_test_ovo_p = x_test[(y_test == cls_p), :]\n",
    "    y_test_ovo_p = np.ones(x_test_ovo_p.shape[0])\n",
    "\n",
    "    for cls_n in class_list:\n",
    "        if cls_p == cls_n:\n",
    "            continue\n",
    "\n",
    "        # Training data of negative class\n",
    "        x_train_ovo_n = x_train[(y_train == cls_n), :]\n",
    "        y_train_ovo_n = np.zeros(x_train_ovo_n.shape[0])\n",
    "\n",
    "        # Testing data of negative class\n",
    "        x_test_ovo_n = x_test[(y_test == cls_n), :]\n",
    "        y_test_ovo_n = np.zeros(x_test_ovo_n.shape[0])\n",
    "\n",
    "        # Concatenate data for training\n",
    "        x_train_ovo = np.concatenate((x_train_ovo_p, x_train_ovo_n), axis=0)\n",
    "        y_train_ovo = np.concatenate((y_train_ovo_p, y_train_ovo_n), axis=0)\n",
    "\n",
    "        # Model training\n",
    "        start_time = time.time()\n",
    "        mdl_logic_matrix[cls_p][cls_n].fit(x_train_ovo, y_train_ovo)\n",
    "        end_time = time.time()\n",
    "\n",
    "        # Concatenate data for testing\n",
    "        x_test_ovo = np.concatenate((x_test_ovo_p, x_test_ovo_n), axis=0)\n",
    "        y_test_ovo = np.concatenate((y_test_ovo_p, y_test_ovo_n), axis=0)\n",
    "\n",
    "        # Test model on sub-task\n",
    "        y_proba_ovo = mdl_logic_matrix[cls_p][cls_n].predict_proba(x_test_ovo) # Output ratio\n",
    "\n",
    "        # Display results\n",
    "        _, (tp, fp, tn, fn) = cls_counts(y_test_ovo, y_proba_ovo[:, 1])\n",
    "        precision, recall, specificity, accuracy, f1 = get_scores(tp, fp, tn, fn)\n",
    "        print(f'Training class {cls_p} ({x_train_ovo_p.shape[0]}) vs class {cls_n} ({x_train_ovo_n.shape[0]}), Training time: {end_time - start_time:.2f} seconds, Precision: {precision:.4f}, Recall (Sensitivity): {recall:.4f}, Specificity: {specificity:.4f}, Accuracy: {accuracy:.4f}, F1-Score: {f1:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4899e435-f9ad-49c7-a6b1-e182a94d8f21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the pairwise (one-vs-one) models on the full test set\n",
    "x_test_select = x_test\n",
    "\n",
    "# Vote accumulator: one column per class\n",
    "y_pred_counts = np.zeros((x_test_select.shape[0], len(class_list)))\n",
    "\n",
    "# Model mdl_logic_matrix[p][n] votes for class p when it predicts 1;\n",
    "# the mirrored model [n][p] supplies the votes for class n.\n",
    "for cls_p in class_list:\n",
    "    for cls_n in class_list:\n",
    "        if cls_p == cls_n:\n",
    "            continue\n",
    "\n",
    "        y_pred_counts[:, cls_p] += mdl_logic_matrix[cls_p][cls_n].predict(x_test_select)\n",
    "\n",
    "# Predicted class = the one with the most votes\n",
    "y_pred = np.argmax(y_pred_counts, axis=1)\n",
    "\n",
    "# Accuracy\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f'Accuracy: {accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ac6f35f-b422-450a-93f3-7cb892feddd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "mdl_softmax = LogisticRegression(max_iter=1000, solver='lbfgs')\n",
    "\n",
    "start_time = time.time()\n",
    "mdl_softmax.fit(x_train, y_train)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f'Training time: {end_time - start_time:.2f} seconds')\n",
    "\n",
    "# Evaluate accuracy (or other metrics)\n",
    "y_pred = mdl_softmax.predict(x_test)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(\"Accuracy:\", accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "111e90b3-1b9d-46c5-a97c-8cce7a19be8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encoding\n",
    "def one_hot_encode(y, num_classes):\n",
    "    \"\"\"Converts integer labels to one-hot encoding.\"\"\"\n",
    "    one_hot = np.zeros((y.shape[0], num_classes))\n",
    "    one_hot[np.arange(y.shape[0]), y] = 1\n",
    "    return one_hot\n",
    "\n",
    "# Example usage:\n",
    "num_classes = len(class_list)\n",
    "y_train_onehot = one_hot_encode(y_train, num_classes)\n",
    "\n",
    "# Display one-hot encoding results of ten random sample\n",
    "for _ in range(10):\n",
    "    idx = np.random.randint(0, y_train_onehot.shape[0])\n",
    "\n",
    "    print(f'Sample {idx+1},\\t Class {y_train[idx]}: {y_train_onehot[idx,:]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41f53b71-7b86-4174-ab6e-0c33236054db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax function\n",
    "def softmax(x):\n",
    "    \"\"\"Compute softmax values for each set of scores in x (numerically stable).\"\"\"\n",
    "    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))\n",
    "    return e_x / e_x.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Cross-entropy loss\n",
    "def cross_entropy_loss(y, y_pred):\n",
    "    \"\"\"Compute mean cross-entropy loss between one-hot targets y and predictions y_pred.\"\"\"\n",
    "    epsilon = 1e-15  # Small value to avoid log(0)\n",
    "    loss = -np.sum(y * np.log(y_pred + epsilon)) / y.shape[0]\n",
    "    return loss\n",
    "\n",
    "def gradient_descent(x, y, learning_rate, num_iterations):\n",
    "    \"\"\"Performs Adagrad-style gradient descent for softmax regression.\n",
    "\n",
    "    x: (num_samples, num_features) inputs; y: (num_samples, num_classes) one-hot labels.\n",
    "    Returns the learned weights w and bias b.\n",
    "    \"\"\"\n",
    "    num_samples, num_features = x.shape\n",
    "    num_classes = y.shape[1]\n",
    "\n",
    "    # Initialize weights and bias\n",
    "    w = np.random.randn(num_features, num_classes)\n",
    "    b = np.zeros(num_classes)\n",
    "\n",
    "    # Initialize Adagrad accumulators of squared gradients\n",
    "    eps = 1e-8  # Guards against 0/0 -> NaN when a gradient component is exactly zero\n",
    "    lr_w = np.zeros(w.shape)\n",
    "    lr_b = np.zeros(num_classes)\n",
    "\n",
    "    for i in range(num_iterations):\n",
    "        # Forward pass\n",
    "        scores = np.dot(x, w) + b\n",
    "        y_pred = softmax(scores)\n",
    "\n",
    "        # Compute loss\n",
    "        loss = cross_entropy_loss(y, y_pred)\n",
    "\n",
    "        # Backward pass (compute gradients)\n",
    "        dw = (1 / num_samples) * np.dot(x.T, (y_pred - y))\n",
    "        db = (1 / num_samples) * np.sum(y_pred - y, axis=0)\n",
    "\n",
    "        # Accumulate squared gradients\n",
    "        lr_w += dw ** 2\n",
    "        lr_b += db ** 2\n",
    "\n",
    "        # Update parameters with Adagrad scaling (epsilon in the denominator)\n",
    "        w -= learning_rate / (np.sqrt(lr_w) + eps) * dw\n",
    "        b -= learning_rate / (np.sqrt(lr_b) + eps) * db\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            print(f'Iteration {i}, Loss: {loss}')\n",
    "\n",
    "    return w, b\n",
    "\n",
    "def predict(x, w, b):\n",
    "    \"\"\"Predicts class labels for input data.\"\"\"\n",
    "    scores = np.dot(x, w) + b\n",
    "    y_pred = softmax(scores)\n",
    "    return np.argmax(y_pred, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37e5b5a4-ce0c-4f41-99e1-6f29c73196cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_exp = np.array([0, 0, 1])\n",
    "y_exp_pred = np.array([0.22, 0.28, 0.50])\n",
    "print(cross_entropy_loss(y_exp, y_exp_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85f1fea0-ce56-4b53-973c-9e4c115b554e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform gradient descent\n",
    "start_time = time.time()\n",
    "w, b = gradient_descent(x_train, y_train_onehot, learning_rate=1, num_iterations=1000)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f'Training time: {end_time - start_time:.2f} seconds')\n",
    "\n",
    "# Make predictions\n",
    "y_pred = predict(x_test, w, b)\n",
    "\n",
    "# Evaluate accuracy (or other metrics)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(\"Accuracy:\", accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d53f72-e65d-41ed-b1e4-9248385b811d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
